# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional, Union

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal


@dataclass
class BuildConfig:
    """Engine build-time limits attached to every benchmark configuration.

    Field order is part of the generated positional constructor signature,
    so it must stay stable.
    """
    # TODO: Use `tensorrt_llm.builder.BuildConfig` when we switch to use Pytest.

    # Sequence-length limits.
    max_input_len: int = field(default=256)
    max_seq_len: int = field(default=512)
    # Batch / beam limits.
    opt_batch_size: int = field(default=8)
    max_batch_size: int = field(default=8)
    max_beam_width: int = field(default=1)
    # Token-count limits; None leaves them unset.
    max_num_tokens: Optional[int] = field(default=None)
    opt_num_tokens: Optional[int] = field(default=None)
    # Only relevant to encoder-decoder builds (enc-dec DecoderModel).
    max_encoder_input_len: int = field(default=1)


@dataclass
class ModelConfig:
    """Architecture hyper-parameters for the "gpt" and "bert" benchmark types.

    Fields without a default are required. Many optional fields default to
    None, meaning "not set by this config"; those fields are annotated
    ``Optional[...]`` so the type hints match the values actually stored.
    Field order is part of the generated positional constructor and must not
    change.
    """
    num_layers: int
    num_heads: int
    hidden_size: int
    vocab_size: int
    # Optional because some configs (e.g. the bloom entries) pass None.
    hidden_act: Optional[str]
    n_positions: int
    num_kv_heads: Optional[int] = None
    inter_size: Optional[int] = None
    rotary_dim: Optional[int] = None
    type_vocab_size: Optional[int] = None
    pre_norm: Optional[bool] = None
    do_layer_norm_before: Optional[bool] = None
    enable_context_fmha: bool = True
    multi_block_mode: bool = True
    # The enum name of PositionEmbeddingType.
    # None means using the model family's default value defined in the ctor.
    position_embedding_type: Optional[str] = None
    # Only meaningful when the position embedding is RoPE. It defaults to None
    # (rather than a concrete fraction) to prevent accidental use with other
    # position-embedding types.
    rotary_pct: Optional[float] = None
    bias: bool = True
    quantization: Optional[str] = None
    # Mixture-of-experts settings; the zero defaults presumably mean "no MoE"
    # (only gpt_350m_moe sets them among the visible entries) — confirm.
    moe_num_experts: int = 0
    moe_top_k: int = 0
    # Tri-state flags: None = not set by this config.
    use_alibi: Optional[bool] = None
    remove_input_padding: Optional[bool] = None
    parallel_attention: Optional[bool] = None
    new_decoder_architecture: Optional[bool] = None
    state_size: int = 0
    state_dtype: Optional[str] = ""
    conv_kernel: int = 0
    # default_factory gives every instance its own list (no shared mutable default).
    layer_types: List[str] = field(default_factory=list)
    rnn_hidden_size: int = 0
    rnn_head_size: int = 0
    rnn_conv_dim_size: int = 0
    logits_soft_cap: float = 0.0
    use_bias: Optional[bool] = None
    mamba_version: str = 'Mamba1'
    ssm_rmsnorm: bool = True
    ngroups: int = 1
    chunk_size: int = 256


@dataclass
class EncDecModelConfig:
    """Architecture hyper-parameters for the "enc_dec" benchmark type.

    ``head_size`` and ``ffn_hidden_size`` carry a None default only because
    they follow defaulted fields (a dataclass requirement). They are in fact
    required and are validated in ``__post_init__``.

    Raises:
        ValueError: if ``head_size`` or ``ffn_hidden_size`` is None.
    """
    num_layers: int
    num_heads: int
    hidden_size: int
    vocab_size: int
    hidden_act: Optional[str]
    n_positions: int = 0
    # None presumably means "same as num_layers" — confirm at the consumer.
    num_decoder_layers: Optional[int] = None
    head_size: Optional[int] = None
    ffn_hidden_size: Optional[int] = None
    # Presumably T5-style relative-attention bucketing (set by the t5 and
    # flan_t5 entries) — confirm.
    num_buckets: int = 0
    max_distance: int = 0
    has_embedding_scale: bool = False
    normalize_before: Optional[bool] = None
    n_mels: Optional[int] = None
    skip_cross_kv: bool = False
    use_implicit_relative_attention: Optional[bool] = False
    decoder_start_token_id: Optional[int] = None

    def __post_init__(self) -> None:
        # Explicit raises instead of ``assert`` so that validation still runs
        # when Python is started with -O (which strips assert statements).
        for name in ("head_size", "ffn_hidden_size"):
            if getattr(self, name) is None:
                raise ValueError(
                    f"EncDecModelConfig.{name} is required but was not set")


@dataclass
class Config:
    """One named benchmark entry: build limits plus model architecture.

    Entries with ``benchmark_type == "enc_dec"`` carry an
    ``EncDecModelConfig``, while the "gpt" and "bert" entries carry a
    ``ModelConfig``. ``model_config`` is therefore annotated with the union of
    both.
    """
    # Entry name; in ``_allowed_configs`` it matches the dict key.
    name: str
    # Model family, e.g. "gpt", "llama", "t5".
    family: str
    benchmark_type: Literal["gpt", "bert", "enc_dec"]
    build_config: BuildConfig
    model_config: Union[ModelConfig, EncDecModelConfig]


_allowed_configs = {
    "gpt_350m":
    Config(name="gpt_350m",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=16,
               hidden_size=1024,
               vocab_size=51200,
               hidden_act='gelu',
               n_positions=1024,
           )),
    "gpt_1.5b":
    Config(name="gpt_1.5b",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=48,
               num_heads=25,
               hidden_size=1600,
               vocab_size=51200,
               hidden_act='gelu',
               n_positions=1024,
           )),
    "gpt_175b":
    Config(name="gpt_175b",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=96,
               num_heads=96,
               hidden_size=12288,
               vocab_size=51200,
               hidden_act='gelu',
               n_positions=2048,
           )),
    "gpt_350m_moe":
    Config(name="gpt_350m_moe",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=16,
               hidden_size=1024,
               vocab_size=51200,
               hidden_act='gelu',
               n_positions=1024,
               moe_num_experts=8,
               moe_top_k=1,
           )),
    "gpt_350m_sq_per_tensor":
    Config(name="gpt_350m_sq_per_tensor",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=16,
               hidden_size=1024,
               vocab_size=51200,
               hidden_act='gelu',
               n_positions=1024,
               quantization="int8_sq_per_tensor",
           )),
    "gpt_350m_sq_per_token_channel":
    Config(name="gpt_350m_sq_per_token_channel",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=16,
               hidden_size=1024,
               vocab_size=51200,
               hidden_act='gelu',
               n_positions=1024,
               quantization="int8_sq_per_token_channel",
           )),
    "gpt_next_2b":
    Config(name="gpt_next_2b",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=16,
               hidden_size=2048,
               vocab_size=256000,
               hidden_act='swiglu',
               n_positions=1024,
               position_embedding_type='rope_gpt_neox',
               rotary_pct=0.5,
               bias=False,
           )),
    "opt_350m":
    Config(name="opt_350m",
           family="opt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=16,
               hidden_size=1024,
               vocab_size=50272,
               hidden_act='relu',
               n_positions=2048,
               pre_norm=False,
               do_layer_norm_before=False,
           )),
    "opt_2.7b":
    Config(name="opt_2.7b",
           family="opt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               hidden_size=2560,
               vocab_size=50272,
               hidden_act='relu',
               n_positions=2048,
               pre_norm=False,
               do_layer_norm_before=True,
           )),
    "opt_6.7b":
    Config(name="opt_6.7b",
           family="opt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               hidden_size=4096,
               vocab_size=50272,
               hidden_act='relu',
               n_positions=2048,
               pre_norm=False,
               do_layer_norm_before=True,
           )),
    "opt_30b":
    Config(name="opt_30b",
           family="opt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=48,
               num_heads=56,
               hidden_size=7168,
               vocab_size=50272,
               hidden_act='relu',
               n_positions=2048,
               pre_norm=False,
               do_layer_norm_before=True,
           )),
    "opt_66b":
    Config(name="opt_66b",
           family="opt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=64,
               num_heads=72,
               hidden_size=9216,
               vocab_size=50272,
               hidden_act='relu',
               n_positions=2048,
               pre_norm=True,
               do_layer_norm_before=True,
           )),
    "starcoder_15.5b":
    Config(name="starcoder_15.5b",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=40,
               num_heads=48,
               num_kv_heads=1,
               hidden_size=6144,
               vocab_size=49152,
               hidden_act='gelu',
               n_positions=8192,
           )),
    "starcoder2_3b":
    Config(name="starcoder2_3b",
           family="gpt",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=30,
               num_heads=24,
               num_kv_heads=2,
               hidden_size=3072,
               vocab_size=49152,
               hidden_act='gelu',
               n_positions=16384,
               position_embedding_type='rope_gpt_neox',
               rotary_pct=1.0,
           )),
    "llama_7b":
    Config(name="llama_7b",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               hidden_size=4096,
               vocab_size=32000,
               hidden_act='silu',
               n_positions=4096,
               inter_size=11008,
           )),
    "llama_13b":
    Config(name="llama_13b",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=40,
               num_heads=40,
               hidden_size=5120,
               vocab_size=32000,
               hidden_act='silu',
               n_positions=4096,
               inter_size=13824,
           )),
    "llama_30b":
    Config(name="llama_30b",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=60,
               num_heads=52,
               hidden_size=6656,
               vocab_size=32000,
               hidden_act='silu',
               n_positions=2048,
               inter_size=17920,
           )),
    "llama_70b":
    Config(name="llama_70b",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=80,
               num_heads=64,
               num_kv_heads=8,
               hidden_size=8192,
               vocab_size=32000,
               hidden_act='silu',
               n_positions=4096,
               inter_size=28672,
           )),
    "llama_70b_long_context":
    Config(name="llama_70b_long_context",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=16,
               max_input_len=8000,
               max_seq_len=8200,
           ),
           model_config=ModelConfig(
               num_layers=80,
               num_heads=64,
               num_kv_heads=8,
               hidden_size=8192,
               vocab_size=32000,
               hidden_act='silu',
               n_positions=4096,
               inter_size=28672,
               multi_block_mode=True,
           )),
    "llama_70b_long_generation":
    Config(name="llama_70b_long_generation",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=200,
               max_seq_len=16584,
           ),
           model_config=ModelConfig(
               num_layers=80,
               num_heads=64,
               num_kv_heads=8,
               hidden_size=8192,
               vocab_size=32000,
               hidden_act='silu',
               n_positions=4096,
               inter_size=28672,
               multi_block_mode=True,
           )),
    "llama_70b_sq_per_tensor":
    Config(name="llama_70b_sq_per_tensor",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=80,
               num_heads=64,
               num_kv_heads=8,
               hidden_size=8192,
               vocab_size=32000,
               hidden_act='silu',
               n_positions=4096,
               inter_size=28672,
               quantization="int8_sq_per_tensor",
           )),
    "gptj_6b":
    Config(name="gptj_6b",
           family="gptj",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=28,
               num_heads=16,
               hidden_size=4096,
               vocab_size=50401,
               hidden_act='gelu',
               n_positions=1024,
               rotary_dim=64,
           )),
    "gptneox_20b":
    Config(name="gptneox_20b",
           family="gptneox",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=16,
               max_input_len=512,
               max_seq_len=1024,
           ),
           model_config=ModelConfig(
               num_layers=44,
               num_heads=64,
               hidden_size=6144,
               vocab_size=50432,
               hidden_act='gelu',
               n_positions=2048,
               rotary_dim=24,
           )),
    "chatglm_6b":
    Config(name="chatglm_6b",
           family="chatglm",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=28,
               num_heads=32,
               num_kv_heads=32,
               hidden_size=4096,
               inter_size=16384,
               vocab_size=130528,
               hidden_act='gelu',
               n_positions=2048,
               remove_input_padding=False,
           )),
    "chatglm2_6b":
    Config(name="chatglm2_6b",
           family="chatglm2",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=28,
               num_heads=32,
               num_kv_heads=2,
               hidden_size=4096,
               inter_size=13696,
               vocab_size=65024,
               hidden_act='swiglu',
               n_positions=2048,
               remove_input_padding=False,
           )),
    "chatglm3_6b":
    Config(name="chatglm3_6b",
           family="chatglm3",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=28,
               num_heads=32,
               num_kv_heads=2,
               hidden_size=4096,
               inter_size=13696,
               vocab_size=65024,
               hidden_act='swiglu',
               n_positions=2048,
               remove_input_padding=False,
           )),
    "glm_10b":
    Config(name="glm_10b",
           family="glm",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=1024,
               max_seq_len=1280,
           ),
           model_config=ModelConfig(
               num_layers=48,
               num_heads=64,
               num_kv_heads=64,
               hidden_size=4096,
               inter_size=16384,
               vocab_size=50304,
               hidden_act='gelu',
               n_positions=1024,
               remove_input_padding=False,
           )),
    "bloom_560m":
    Config(name="bloom_560m",
           family="bloom",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=32,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=16,
               hidden_size=1024,
               vocab_size=250880,
               hidden_act=None,
               n_positions=2048,
           )),
    "bloom_176b":
    Config(name="bloom_176b",
           family="bloom",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=8,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=70,
               num_heads=112,
               hidden_size=14336,
               vocab_size=250880,
               hidden_act=None,
               n_positions=2048,
           )),
    "bert_base":
    Config(name="bert_base",
           family="bert",
           benchmark_type="bert",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=512,
           ),
           model_config=ModelConfig(
               num_layers=12,
               num_heads=12,
               hidden_size=768,
               vocab_size=30522,
               type_vocab_size=2,
               hidden_act='gelu',
               n_positions=1024,
               enable_context_fmha=False,
           )),
    "bert_large":
    Config(name="bert_large",
           family="bert",
           benchmark_type="bert",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=16,
               hidden_size=1024,
               vocab_size=30522,
               type_vocab_size=2,
               hidden_act='gelu',
               n_positions=1024,
               enable_context_fmha=False,
           )),
    "roberta_base":
    Config(name="roberta_base",
           family="roberta",
           benchmark_type="bert",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
           ),
           model_config=ModelConfig(
               num_layers=12,
               num_heads=12,
               hidden_size=768,
               vocab_size=50265,
               type_vocab_size=1,
               hidden_act='gelu',
               n_positions=1024,
               enable_context_fmha=False,
           )),
    "falcon_rw_1b":
    Config(name="falcon_rw_1b",
           family="falcon",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=256,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=32,
               hidden_size=2048,
               vocab_size=50304,
               hidden_act='gelu',
               n_positions=2048,
               bias=True,
               use_alibi=True,
               parallel_attention=False,
               new_decoder_architecture=False,
           )),
    "falcon_7b":
    Config(name="falcon_7b",
           family="falcon",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=71,
               num_kv_heads=1,
               hidden_size=4544,
               vocab_size=65024,
               hidden_act='gelu',
               n_positions=2048,
               bias=False,
               use_alibi=False,
               parallel_attention=True,
               new_decoder_architecture=False,
           )),
    "falcon_40b":
    Config(name="falcon_40b",
           family="falcon",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=60,
               num_heads=128,
               num_kv_heads=8,
               hidden_size=8192,
               vocab_size=65024,
               hidden_act='gelu',
               n_positions=2048,
               bias=False,
               use_alibi=False,
               parallel_attention=True,
               new_decoder_architecture=True,
           )),
    "falcon_180b":
    Config(name="falcon_180b",
           family="falcon",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=8,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=80,
               num_heads=232,
               num_kv_heads=8,
               hidden_size=14848,
               vocab_size=65024,
               hidden_act='gelu',
               n_positions=2048,
               bias=False,
               use_alibi=False,
               parallel_attention=True,
               new_decoder_architecture=True,
           )),
    "t5_small":
    Config(name="t5_small",
           family="t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=6,
               num_heads=8,
               head_size=64,
               ffn_hidden_size=2048,
               hidden_size=512,
               vocab_size=32128,
               hidden_act="relu",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "t5_base":
    Config(name="t5_base",
           family="t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=12,
               num_heads=12,
               head_size=64,
               ffn_hidden_size=3072,
               hidden_size=768,
               vocab_size=32128,
               hidden_act="relu",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "t5_large":
    Config(name="t5_large",
           family="t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=24,
               num_heads=16,
               head_size=64,
               ffn_hidden_size=4096,
               hidden_size=1024,
               vocab_size=32128,
               hidden_act="relu",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "t5_3b":
    Config(name="t5_3b",
           family="t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=24,
               num_heads=32,
               head_size=128,
               ffn_hidden_size=16384,
               hidden_size=1024,
               vocab_size=32128,
               hidden_act="relu",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "t5_11b":
    Config(name="t5_11b",
           family="t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=24,
               num_heads=128,
               head_size=128,
               ffn_hidden_size=65536,
               hidden_size=1024,
               vocab_size=32128,
               hidden_act="relu",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "flan_t5_small":
    Config(name="flan_t5_small",
           family="flan_t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=8,
               num_decoder_layers=8,
               num_heads=6,
               head_size=64,
               ffn_hidden_size=1024,
               hidden_size=512,
               vocab_size=32128,
               hidden_act="gelu_new",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "flan_t5_base":
    Config(name="flan_t5_base",
           family="flan_t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=12,
               num_decoder_layers=12,
               num_heads=12,
               head_size=64,
               ffn_hidden_size=2048,
               hidden_size=768,
               vocab_size=32128,
               hidden_act="gelu_new",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "flan_t5_large":
    Config(name="flan_t5_large",
           family="flan_t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=24,
               num_decoder_layers=24,
               num_heads=16,
               head_size=64,
               ffn_hidden_size=2816,
               hidden_size=1024,
               vocab_size=32128,
               hidden_act="gelu_new",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "flan_t5_xl":
    Config(name="flan_t5_xl",
           family="flan_t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=24,
               num_decoder_layers=24,
               num_heads=32,
               head_size=64,
               ffn_hidden_size=5120,
               hidden_size=2048,
               vocab_size=32128,
               hidden_act="gelu_new",
               n_positions=512,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "flan_t5_xxl":
    Config(name="flan_t5_xxl",
           family="flan_t5",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=24,
               num_decoder_layers=24,
               num_heads=64,
               head_size=64,
               ffn_hidden_size=10240,
               hidden_size=4096,
               vocab_size=32128,
               hidden_act="gelu_new",
               n_positions=0,
               num_buckets=32,
               max_distance=128,
               decoder_start_token_id=0,
           )),
    "bart_large_cnn":
    Config(name="bart_large_cnn",
           family="bart",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=12,
               num_decoder_layers=12,
               num_heads=16,
               head_size=64,
               ffn_hidden_size=4096,
               hidden_size=1024,
               vocab_size=50265,
               hidden_act="gelu",
               n_positions=1024,
               num_buckets=32,
               has_embedding_scale=False,
               normalize_before=False,
               decoder_start_token_id=2,
           )),
    "mbart_large_50_many_to_one_mmt":
    Config(name="mbart_large_50_many_to_one_mmt",
           family="bart",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1024,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=12,
               num_decoder_layers=12,
               num_heads=16,
               head_size=64,
               ffn_hidden_size=4096,
               hidden_size=1024,
               vocab_size=250054,
               hidden_act="relu",
               n_positions=1024,
               has_embedding_scale=True,
               normalize_before=True,
               decoder_start_token_id=2,
           )),
    "baichuan_7b":
    Config(name="baichuan_7b",
           family="baichuan",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               hidden_size=4096,
               vocab_size=64000,
               hidden_act='silu',
               n_positions=4096,
               inter_size=11008,
           )),
    "baichuan2_7b_chat":
    Config(name="baichuan2_7b_chat",
           family="baichuan",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               hidden_size=4096,
               vocab_size=125696,
               hidden_act='silu',
               n_positions=4096,
               inter_size=11008,
           )),
    "baichuan_13b_chat":
    Config(name="baichuan_13b_chat",
           family="baichuan",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=40,
               num_heads=40,
               hidden_size=5120,
               vocab_size=64000,
               hidden_act='silu',
               n_positions=4096,
               inter_size=13696,
           )),
    "baichuan2_13b_chat":
    Config(name="baichuan2_13b_chat",
           family="baichuan",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=40,
               num_heads=40,
               hidden_size=5120,
               vocab_size=125696,
               hidden_act='silu',
               n_positions=4096,
               inter_size=13696,
           )),
    "internlm_chat_7b":
    Config(name="internlm_chat_7b",
           family="internlm",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               num_kv_heads=32,
               hidden_size=4096,
               vocab_size=103168,
               hidden_act='silu',
               n_positions=2048,
               inter_size=11008,
               bias=True,
           )),
    "internlm_chat_20b":
    Config(name="internlm_chat_20b",
           family="internlm",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=60,
               num_heads=40,
               num_kv_heads=40,
               hidden_size=5120,
               vocab_size=103168,
               hidden_act='silu',
               n_positions=4096,
               inter_size=13824,
               bias=False,
           )),
    "qwen_7b_chat":
    Config(name="qwen_7b_chat",
           family="qwen",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               hidden_size=4096,
               vocab_size=151936,
               hidden_act='silu',
               n_positions=8192,
               inter_size=22016,
               bias=False,
           )),
    "qwen_14b_chat":
    Config(name="qwen_14b_chat",
           family="qwen",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=40,
               num_heads=40,
               hidden_size=5120,
               vocab_size=152064,
               hidden_act='silu',
               n_positions=8192,
               inter_size=27392,
           )),
    "qwen1.5_7b_chat":
    Config(name="qwen1.5_7b_chat",
           family="qwen2",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               hidden_size=4096,
               vocab_size=151936,
               hidden_act='silu',
               n_positions=8192,
               inter_size=11008,
               bias=False,
           )),
    "qwen1.5_14b_chat":
    Config(name="qwen1.5_14b_chat",
           family="qwen2",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=40,
               num_heads=40,
               hidden_size=5120,
               vocab_size=152064,
               hidden_act='silu',
               n_positions=8192,
               inter_size=13696,
           )),
    "qwen2_7b_instruct":
    Config(name="qwen2_7b_instruct",
           family="qwen2",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=28,
               num_heads=28,
               hidden_size=3584,
               vocab_size=152064,
               hidden_act='silu',
               n_positions=32768,
               inter_size=18944,
           )),
    "mamba_2.8b":
    Config(name="mamba_2.8b",
           family="mamba",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=64,
               num_heads=1,
               hidden_size=2560,
               vocab_size=50280,
               hidden_act="silu",
               n_positions=8192,
               state_size=16,
               conv_kernel=4,
               rnn_hidden_size=5120,
               rnn_conv_dim_size=5120,
               layer_types=["recurrent"],
               use_bias=False,
           )),
    "mamba_1.4b":
    Config(name="mamba_1.4b",
           family="mamba",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=48,
               num_heads=1,
               hidden_size=2048,
               vocab_size=50280,
               hidden_act="silu",
               n_positions=8192,
               state_size=16,
               conv_kernel=4,
               rnn_hidden_size=4096,
               rnn_conv_dim_size=4096,
               layer_types=["recurrent"],
               use_bias=False,
           )),
    "mamba_790m":
    Config(name="mamba_790m",
           family="mamba",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=48,
               num_heads=1,
               hidden_size=1536,
               vocab_size=50280,
               hidden_act="silu",
               n_positions=8192,
               state_size=16,
               conv_kernel=4,
               rnn_hidden_size=3072,
               rnn_conv_dim_size=3072,
               layer_types=["recurrent"],
               use_bias=False,
           )),
    "mamba_370m":
    Config(name="mamba_370m",
           family="mamba",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=48,
               num_heads=1,
               hidden_size=1024,
               vocab_size=50280,
               hidden_act="silu",
               n_positions=8192,
               state_size=16,
               conv_kernel=4,
               rnn_hidden_size=2048,
               rnn_conv_dim_size=2048,
               layer_types=["recurrent"],
               use_bias=False,
           )),
    "mamba_130m":
    Config(name="mamba_130m",
           family="mamba",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=1,
               hidden_size=768,
               vocab_size=50280,
               hidden_act="silu",
               n_positions=8192,
               state_size=16,
               conv_kernel=4,
               rnn_hidden_size=1536,
               rnn_conv_dim_size=1536,
               layer_types=["recurrent"],
               use_bias=False,
           )),
    "mamba2_2.7b":
    Config(name="mamba2_2.7b",
           family="mamba",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=64,
               num_heads=1,
               hidden_size=2560,
               vocab_size=50288,
               hidden_act="silu",
               n_positions=8192,
               state_size=128,
               conv_kernel=4,
               rnn_hidden_size=5120,
               rnn_conv_dim_size=5376,
               rnn_head_size=64,
               layer_types=["recurrent"],
               use_bias=False,
               mamba_version='Mamba2',
               ssm_rmsnorm=True,
               ngroups=1,
               chunk_size=256,
           )),
    "mamba2_130m":
    Config(name="mamba2_130m",
           family="mamba",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=24,
               num_heads=1,
               hidden_size=768,
               vocab_size=50288,
               hidden_act="silu",
               n_positions=8192,
               state_size=128,
               conv_kernel=4,
               rnn_hidden_size=1536,
               rnn_conv_dim_size=1792,
               rnn_head_size=64,
               layer_types=["recurrent"],
               use_bias=False,
               mamba_version='Mamba2',
               ssm_rmsnorm=True,
               ngroups=1,
               chunk_size=256,
           )),
    "whisper_large_v3":
    Config(name="whisper_large_v3",
           family="whisper",
           benchmark_type="enc_dec",
           build_config=BuildConfig(
               max_batch_size=8,
               max_encoder_input_len=1500,
               max_input_len=1,
               max_seq_len=201,
           ),
           model_config=EncDecModelConfig(
               num_layers=32,
               num_decoder_layers=32,
               num_heads=20,
               head_size=64,
               ffn_hidden_size=5120,
               hidden_size=1280,
               vocab_size=51866,
               hidden_act="gelu",
               n_positions=448,
               n_mels=128,
           )),
    "recurrentgemma_2b":
    Config(name="recurrentgemma_2b",
           family="recurrentgemma",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=26,
               num_heads=10,
               num_kv_heads=1,
               hidden_size=2560,
               inter_size=7680,
               vocab_size=256000,
               hidden_act="gelu",
               n_positions=8192,
               position_embedding_type='rope_gpt_neox',
               rotary_pct=0.5,
               conv_kernel=4,
               state_size=1,
               layer_types=["recurrent", "recurrent", "attention"],
               rnn_hidden_size=2560,
               rnn_conv_dim_size=2560,
               logits_soft_cap=30.0,
               state_dtype="float32",
           )),
    "llama_v3_8b_instruct":
    Config(name="llama_v3_8b_instruct",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               num_kv_heads=8,
               hidden_size=4096,
               vocab_size=128256,
               hidden_act='silu',
               n_positions=8192,
               inter_size=14336,
           )),
    "mistral_7b_v0.1":
    Config(name="mistral_7b_v0.1",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               num_kv_heads=8,
               hidden_size=4096,
               vocab_size=32000,
               hidden_act='silu',
               n_positions=32768,
               inter_size=14336,
           )),
    "mixtral_8x7b":
    Config(name="mixtral_8x7b",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               num_kv_heads=8,
               hidden_size=4096,
               vocab_size=32000,
               hidden_act='swiglu',
               n_positions=32768,
               inter_size=14336,
               moe_num_experts=8,
               moe_top_k=2,
           )),
    "mixtral_8x22b_v0.1":
    Config(name="mixtral_8x22b_v0.1",
           family="llama",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=128,
               max_input_len=512,
               max_seq_len=712,
           ),
           model_config=ModelConfig(
               num_layers=56,
               num_heads=48,
               num_kv_heads=8,
               hidden_size=6144,
               vocab_size=32000,
               hidden_act='swiglu',
               n_positions=65536,
               inter_size=16384,
               moe_num_experts=8,
               moe_top_k=2,
           )),
    "phi_3_mini_4k_instruct":
    Config(name="phi_3_mini_4k_instruct",
           family="phi3",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               num_kv_heads=32,
               hidden_size=3072,
               vocab_size=32064,
               hidden_act='silu',
               n_positions=4096,
               inter_size=8192,
           )),
    "phi_3_mini_128k_instruct":
    Config(name="phi_3_mini_128k_instruct",
           family="phi3",
           benchmark_type="gpt",
           build_config=BuildConfig(
               max_batch_size=64,
               max_input_len=1024,
               max_seq_len=2048,
           ),
           model_config=ModelConfig(
               num_layers=32,
               num_heads=32,
               num_kv_heads=8,
               hidden_size=4096,
               vocab_size=128256,
               hidden_act='silu',
               n_positions=131072,
               inter_size=14336,
           )),
}


def get_allowed_models(benchmark_type=None):
    """Return the set of model names accepted by the lookup helpers below.

    Args:
        benchmark_type: Optional benchmark type to filter on (values used in
            ``_allowed_configs`` include ``"gpt"`` and ``"enc_dec"``). When
            ``None``, every registered model is returned.

    Returns:
        A ``set`` of keys of ``_allowed_configs``. Every returned name can be
        passed to ``get_build_config``, ``get_model_config``,
        ``get_model_family`` and ``get_benchmark_type``.
    """
    if benchmark_type is None:
        return set(_allowed_configs)
    # Return registry keys (not ``Config.name``) so that both branches yield
    # the same kind of identifier; the lookup helpers index by key.
    return {
        model_name
        for model_name, cfg in _allowed_configs.items()
        if cfg.benchmark_type == benchmark_type
    }


def get_build_config(model_name, return_dict=True) -> Union[Dict, BuildConfig]:
    """Look up the build configuration registered for ``model_name``.

    Args:
        model_name: Key into ``_allowed_configs``.
        return_dict: When true (the default), return the config converted by
            ``dataclasses.asdict``. Otherwise return the stored
            ``BuildConfig`` instance itself.

    Returns:
        A ``dict``, or the registry's own ``BuildConfig`` object. That object
        is shared, so mutating it affects later lookups.

    Raises:
        KeyError: If ``model_name`` is not a registered model.
    """
    try:
        entry = _allowed_configs[model_name]
    except KeyError:
        raise KeyError(f'Unexpected model: {model_name}. Please add the model '
                       'to allowed_configs.py') from None
    build_cfg = entry.build_config
    if return_dict:
        return asdict(build_cfg)
    return build_cfg


def get_model_config(model_name, return_dict=True) -> Union[Dict, ModelConfig]:
    """Look up the model-architecture configuration for ``model_name``.

    Depending on the entry, the stored object is a ``ModelConfig`` or an
    ``EncDecModelConfig`` (the enc-dec entries in ``_allowed_configs`` use
    the latter).

    Args:
        model_name: Key into ``_allowed_configs``.
        return_dict: When true (the default), return the config converted by
            ``dataclasses.asdict``. Otherwise return the stored dataclass
            instance itself.

    Returns:
        A ``dict``, or the registry's own config object (shared, not copied).

    Raises:
        KeyError: If ``model_name`` is not a registered model.
    """
    if model_name not in _allowed_configs:
        raise KeyError(f'Unexpected model: {model_name}. Please add the model '
                       'to allowed_configs.py')
    model_cfg = _allowed_configs[model_name].model_config
    if not return_dict:
        return model_cfg
    return asdict(model_cfg)


def get_model_family(model_name):
    """Return the ``family`` string registered for ``model_name``.

    Args:
        model_name: Key into ``_allowed_configs``.

    Returns:
        The entry's ``family`` attribute (e.g. ``"llama"``, ``"mamba"``).

    Raises:
        KeyError: If ``model_name`` is not a registered model.
    """
    try:
        registered = _allowed_configs[model_name]
    except KeyError:
        raise KeyError(f'Unexpected model: {model_name}. Please add the model '
                       'to allowed_configs.py') from None
    return registered.family


def get_benchmark_type(model_name):
    """Return the ``benchmark_type`` registered for ``model_name``.

    Args:
        model_name: Key into ``_allowed_configs``.

    Returns:
        The entry's ``benchmark_type`` attribute (e.g. ``"gpt"`` or
        ``"enc_dec"``).

    Raises:
        KeyError: If ``model_name`` is not a registered model.
    """
    if model_name not in _allowed_configs:
        raise KeyError(f'Unexpected model: {model_name}. Please add the model '
                       'to allowed_configs.py')
    registered = _allowed_configs[model_name]
    return registered.benchmark_type
